{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451d89db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BertModel(\n",
      "  (embeddings): BertEmbeddings(\n",
      "    (word_embeddings): Embedding(21128, 768, padding_idx=0)\n",
      "    (position_embeddings): Embedding(512, 768)\n",
      "    (token_type_embeddings): Embedding(2, 768)\n",
      "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "    (dropout): Dropout(p=0.1, inplace=False)\n",
      "  )\n",
      "  (encoder): BertEncoder(\n",
      "    (layer): ModuleList(\n",
      "      (0-11): 12 x BertLayer(\n",
      "        (attention): BertAttention(\n",
      "          (self): BertSdpaSelfAttention(\n",
      "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "          (output): BertSelfOutput(\n",
      "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "        (intermediate): BertIntermediate(\n",
      "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "          (intermediate_act_fn): GELUActivation()\n",
      "        )\n",
      "        (output): BertOutput(\n",
      "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (pooler): BertPooler(\n",
      "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "    (activation): Tanh()\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertTokenizer, BertModel\n",
    "import torch\n",
    "\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "#加载预训练模型\n",
    "pre_model = BertModel.from_pretrained(r\"E:\\nlp\\model_cache\\models--google-bert--bert-base-chinese\\snapshots\\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f\").to(DEVICE)\n",
    "print(pre_model)\n",
    "\n",
    "#定义下游任务（增量模型）\n",
    "class Classifier(torch.nn.Module):\n",
    "    def __init__(self,model):\n",
    "        super(Classifier, self).__init__()\n",
    "       \n",
    "        self.fc = torch.nn.Linear(768,2)\n",
    "        # self.softmax = torch.nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self,input_ids,attention_mask,token_type_ids):\n",
    "        with torch.no_grad():\n",
    "            out = pre_model(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)\n",
    "\n",
    "        #增量模型参与训练\n",
    "        out = self.fc(out.last_hidden_state[:,0])\n",
    "        return out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
